{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "(sharded-computation)=\n",
    "# Introduction to parallel programming\n",
    "\n",
    "<!--* freshness: { reviewed: '2024-05-10' } *-->\n",
    "\n",
    "This tutorial serves as an introduction to device parallelism for Single-Program Multi-Data (SPMD) code in JAX. SPMD is a parallelism technique where the same computation, such as the forward pass of a neural network, can be run on different input data (for example, different inputs in a batch) in parallel on different devices, such as several GPUs or Google TPUs.\n",
    "\n",
    "The tutorial covers three modes of parallel computation:\n",
    "\n",
    "- _Automatic parallelism via {func}`jax.jit`_: The compiler chooses the optimal computation strategy (a.k.a. \"the compiler takes the wheel\").\n",
    "- _Semi-automated parallelism_ using {func}`jax.jit` and {func}`jax.lax.with_sharding_constraint`\n",
    "- _Fully manual parallelism with manual control using {func}`jax.experimental.shard_map.shard_map`_: `shard_map` enables per-device code and explicit communication collectives\n",
    "\n",
    "Using these schools of thought for SPMD, you can transform a function written for one device into a function that can run in parallel on multiple devices.\n",
    "\n",
    "If you are running these examples in a Google Colab notebook, make sure that your hardware accelerator is the latest Google TPU by checking your notebook settings: **Runtime** > **Change runtime type** > **Hardware accelerator** > **TPU v2** (which provides eight devices to work with)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "outputId": "18905ae4-7b5e-4bb9-acb4-d8ab914cb456"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),\n",
       " TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),\n",
       " TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),\n",
       " TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),\n",
       " TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),\n",
       " TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),\n",
       " TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),\n",
       " TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import jax\n",
    "jax.devices()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Key concept: Data sharding\n",
    "\n",
    "Key to all of the distributed computation approaches below is the concept of *data sharding*, which describes how data is laid out on the available devices.\n",
    "\n",
    "How can JAX understand how the data is laid out across devices? JAX's datatype, the {class}`jax.Array` immutable array data structure, represents arrays with physical storage spanning one or multiple devices, and helps make parallelism a core feature of JAX.  The {class}`jax.Array` object is designed with distributed data and computation in mind. Every `jax.Array` has an associated {mod}`jax.sharding.Sharding` object, which describes which shard of the global data is required by each global device. When you create a {class}`jax.Array` from scratch, you also need to create its `Sharding`.\n",
    "\n",
    "In the simplest cases, arrays are sharded on a single device, as demonstrated below:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "outputId": "39fdbb79-d5c0-4ea6-8b20-88b2c502a27a"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0)}"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import jax.numpy as jnp\n",
    "arr = jnp.arange(32.0).reshape(4, 8)\n",
    "arr.devices()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "outputId": "536f773a-7ef4-4526-c58b-ab4d486bf5a1"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "SingleDeviceSharding(device=TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0))"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "arr.sharding"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For a more visual representation of the storage layout, the {mod}`jax.debug` module provides some helpers to visualize the sharding of an array. For example, {func}`jax.debug.visualize_array_sharding` displays how the array is stored in memory of a single device:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "outputId": "74a793e9-b13b-4d07-d8ec-7e25c547036d"
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #ffffff; text-decoration-color: #ffffff; background-color: #393b79\">                                                  </span>\n",
       "<span style=\"color: #ffffff; text-decoration-color: #ffffff; background-color: #393b79\">                                                  </span>\n",
       "<span style=\"color: #ffffff; text-decoration-color: #ffffff; background-color: #393b79\">                                                  </span>\n",
       "<span style=\"color: #ffffff; text-decoration-color: #ffffff; background-color: #393b79\">                                                  </span>\n",
       "<span style=\"color: #ffffff; text-decoration-color: #ffffff; background-color: #393b79\">                                                  </span>\n",
       "<span style=\"color: #ffffff; text-decoration-color: #ffffff; background-color: #393b79\">                      TPU 0                       </span>\n",
       "<span style=\"color: #ffffff; text-decoration-color: #ffffff; background-color: #393b79\">                                                  </span>\n",
       "<span style=\"color: #ffffff; text-decoration-color: #ffffff; background-color: #393b79\">                                                  </span>\n",
       "<span style=\"color: #ffffff; text-decoration-color: #ffffff; background-color: #393b79\">                                                  </span>\n",
       "<span style=\"color: #ffffff; text-decoration-color: #ffffff; background-color: #393b79\">                                                  </span>\n",
       "<span style=\"color: #ffffff; text-decoration-color: #ffffff; background-color: #393b79\">                                                  </span>\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[38;2;255;255;255;48;2;57;59;121m                                                  \u001b[0m\n",
       "\u001b[38;2;255;255;255;48;2;57;59;121m                                                  \u001b[0m\n",
       "\u001b[38;2;255;255;255;48;2;57;59;121m                                                  \u001b[0m\n",
       "\u001b[38;2;255;255;255;48;2;57;59;121m                                                  \u001b[0m\n",
       "\u001b[38;2;255;255;255;48;2;57;59;121m                                                  \u001b[0m\n",
       "\u001b[38;2;255;255;255;48;2;57;59;121m                      \u001b[0m\u001b[38;2;255;255;255;48;2;57;59;121mTPU 0\u001b[0m\u001b[38;2;255;255;255;48;2;57;59;121m                       \u001b[0m\n",
       "\u001b[38;2;255;255;255;48;2;57;59;121m                                                  \u001b[0m\n",
       "\u001b[38;2;255;255;255;48;2;57;59;121m                                                  \u001b[0m\n",
       "\u001b[38;2;255;255;255;48;2;57;59;121m                                                  \u001b[0m\n",
       "\u001b[38;2;255;255;255;48;2;57;59;121m                                                  \u001b[0m\n",
       "\u001b[38;2;255;255;255;48;2;57;59;121m                                                  \u001b[0m\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "jax.debug.visualize_array_sharding(arr)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To create an array with a non-trivial sharding, you can define a {mod}`jax.sharding` specification for the array and pass this to {func}`jax.device_put`.\n",
    "\n",
    "Here, define a {class}`~jax.sharding.NamedSharding`, which specifies an N-dimensional grid of devices with named axes, where {class}`jax.sharding.Mesh` allows for precise device placement:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "outputId": "0b397dba-3ddc-4aca-f002-2beab7e6b8a5"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "NamedSharding(mesh=Mesh('x': 2, 'y': 4), spec=PartitionSpec('x', 'y'))\n"
     ]
    }
   ],
   "source": [
    "from jax.sharding import PartitionSpec as P\n",
    "\n",
    "mesh = jax.make_mesh((2, 4), ('x', 'y'))\n",
    "sharding = jax.sharding.NamedSharding(mesh, P('x', 'y'))\n",
    "print(sharding)"
   ]
  },
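  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick aside, you can inspect the mesh itself through its public attributes; here it arranges the eight devices in a 2x4 grid with the named axes `'x'` and `'y'`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The mesh is a 2x4 grid of devices with axes named 'x' and 'y'.\n",
    "print(mesh.shape)\n",
    "print(mesh.devices.shape)"
   ]
  },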
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Passing this `Sharding` object to {func}`jax.device_put`, you can obtain a sharded array:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "outputId": "c8ceedba-05ca-4156-e6e4-1e98bb664a66"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[ 0.  1.  2.  3.  4.  5.  6.  7.]\n",
      " [ 8.  9. 10. 11. 12. 13. 14. 15.]\n",
      " [16. 17. 18. 19. 20. 21. 22. 23.]\n",
      " [24. 25. 26. 27. 28. 29. 30. 31.]]\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #ffffff; text-decoration-color: #ffffff; background-color: #393b79\">            </span><span style=\"color: #ffffff; text-decoration-color: #ffffff; background-color: #d6616b\">            </span><span style=\"color: #ffffff; text-decoration-color: #ffffff; background-color: #8ca252\">            </span><span style=\"color: #ffffff; text-decoration-color: #ffffff; background-color: #de9ed6\">            </span>\n",
       "<span style=\"color: #ffffff; text-decoration-color: #ffffff; background-color: #393b79\">            </span><span style=\"color: #ffffff; text-decoration-color: #ffffff; background-color: #d6616b\">            </span><span style=\"color: #ffffff; text-decoration-color: #ffffff; background-color: #8ca252\">            </span><span style=\"color: #ffffff; text-decoration-color: #ffffff; background-color: #de9ed6\">            </span>\n",
       "<span style=\"color: #ffffff; text-decoration-color: #ffffff; background-color: #393b79\">   TPU 0    </span><span style=\"color: #ffffff; text-decoration-color: #ffffff; background-color: #d6616b\">   TPU 1    </span><span style=\"color: #ffffff; text-decoration-color: #ffffff; background-color: #8ca252\">   TPU 2    </span><span style=\"color: #ffffff; text-decoration-color: #ffffff; background-color: #de9ed6\">   TPU 3    </span>\n",
       "<span style=\"color: #ffffff; text-decoration-color: #ffffff; background-color: #393b79\">            </span><span style=\"color: #ffffff; text-decoration-color: #ffffff; background-color: #d6616b\">            </span><span style=\"color: #ffffff; text-decoration-color: #ffffff; background-color: #8ca252\">            </span><span style=\"color: #ffffff; text-decoration-color: #ffffff; background-color: #de9ed6\">            </span>\n",
       "<span style=\"color: #ffffff; text-decoration-color: #ffffff; background-color: #393b79\">            </span><span style=\"color: #ffffff; text-decoration-color: #ffffff; background-color: #d6616b\">            </span><span style=\"color: #ffffff; text-decoration-color: #ffffff; background-color: #8ca252\">            </span><span style=\"color: #ffffff; text-decoration-color: #ffffff; background-color: #de9ed6\">            </span>\n",
       "<span style=\"color: #ffffff; text-decoration-color: #ffffff; background-color: #393b79\">            </span><span style=\"color: #ffffff; text-decoration-color: #ffffff; background-color: #d6616b\">            </span><span style=\"color: #ffffff; text-decoration-color: #ffffff; background-color: #8ca252\">            </span><span style=\"color: #ffffff; text-decoration-color: #ffffff; background-color: #de9ed6\">            </span>\n",
       "<span style=\"color: #000000; text-decoration-color: #000000; background-color: #e7cb94\">            </span><span style=\"color: #ffffff; text-decoration-color: #ffffff; background-color: #6b6ecf\">            </span><span style=\"color: #ffffff; text-decoration-color: #ffffff; background-color: #a55194\">            </span><span style=\"color: #ffffff; text-decoration-color: #ffffff; background-color: #8c6d31\">            </span>\n",
       "<span style=\"color: #000000; text-decoration-color: #000000; background-color: #e7cb94\">            </span><span style=\"color: #ffffff; text-decoration-color: #ffffff; background-color: #6b6ecf\">            </span><span style=\"color: #ffffff; text-decoration-color: #ffffff; background-color: #a55194\">            </span><span style=\"color: #ffffff; text-decoration-color: #ffffff; background-color: #8c6d31\">            </span>\n",
       "<span style=\"color: #000000; text-decoration-color: #000000; background-color: #e7cb94\">   TPU 6    </span><span style=\"color: #ffffff; text-decoration-color: #ffffff; background-color: #6b6ecf\">   TPU 7    </span><span style=\"color: #ffffff; text-decoration-color: #ffffff; background-color: #a55194\">   TPU 4    </span><span style=\"color: #ffffff; text-decoration-color: #ffffff; background-color: #8c6d31\">   TPU 5    </span>\n",
       "<span style=\"color: #000000; text-decoration-color: #000000; background-color: #e7cb94\">            </span><span style=\"color: #ffffff; text-decoration-color: #ffffff; background-color: #6b6ecf\">            </span><span style=\"color: #ffffff; text-decoration-color: #ffffff; background-color: #a55194\">            </span><span style=\"color: #ffffff; text-decoration-color: #ffffff; background-color: #8c6d31\">            </span>\n",
       "<span style=\"color: #000000; text-decoration-color: #000000; background-color: #e7cb94\">            </span><span style=\"color: #ffffff; text-decoration-color: #ffffff; background-color: #6b6ecf\">            </span><span style=\"color: #ffffff; text-decoration-color: #ffffff; background-color: #a55194\">            </span><span style=\"color: #ffffff; text-decoration-color: #ffffff; background-color: #8c6d31\">            </span>\n",
       "<span style=\"color: #000000; text-decoration-color: #000000; background-color: #e7cb94\">            </span><span style=\"color: #ffffff; text-decoration-color: #ffffff; background-color: #6b6ecf\">            </span><span style=\"color: #ffffff; text-decoration-color: #ffffff; background-color: #a55194\">            </span><span style=\"color: #ffffff; text-decoration-color: #ffffff; background-color: #8c6d31\">            </span>\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[38;2;255;255;255;48;2;57;59;121m            \u001b[0m\u001b[38;2;255;255;255;48;2;214;97;107m            \u001b[0m\u001b[38;2;255;255;255;48;2;140;162;82m            \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m            \u001b[0m\n",
       "\u001b[38;2;255;255;255;48;2;57;59;121m            \u001b[0m\u001b[38;2;255;255;255;48;2;214;97;107m            \u001b[0m\u001b[38;2;255;255;255;48;2;140;162;82m            \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m            \u001b[0m\n",
       "\u001b[38;2;255;255;255;48;2;57;59;121m   \u001b[0m\u001b[38;2;255;255;255;48;2;57;59;121mTPU 0\u001b[0m\u001b[38;2;255;255;255;48;2;57;59;121m    \u001b[0m\u001b[38;2;255;255;255;48;2;214;97;107m   \u001b[0m\u001b[38;2;255;255;255;48;2;214;97;107mTPU 1\u001b[0m\u001b[38;2;255;255;255;48;2;214;97;107m    \u001b[0m\u001b[38;2;255;255;255;48;2;140;162;82m   \u001b[0m\u001b[38;2;255;255;255;48;2;140;162;82mTPU 2\u001b[0m\u001b[38;2;255;255;255;48;2;140;162;82m    \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m   \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214mTPU 3\u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m    \u001b[0m\n",
       "\u001b[38;2;255;255;255;48;2;57;59;121m            \u001b[0m\u001b[38;2;255;255;255;48;2;214;97;107m            \u001b[0m\u001b[38;2;255;255;255;48;2;140;162;82m            \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m            \u001b[0m\n",
       "\u001b[38;2;255;255;255;48;2;57;59;121m            \u001b[0m\u001b[38;2;255;255;255;48;2;214;97;107m            \u001b[0m\u001b[38;2;255;255;255;48;2;140;162;82m            \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m            \u001b[0m\n",
       "\u001b[38;2;255;255;255;48;2;57;59;121m            \u001b[0m\u001b[38;2;255;255;255;48;2;214;97;107m            \u001b[0m\u001b[38;2;255;255;255;48;2;140;162;82m            \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m            \u001b[0m\n",
       "\u001b[38;2;0;0;0;48;2;231;203;148m            \u001b[0m\u001b[38;2;255;255;255;48;2;107;110;207m            \u001b[0m\u001b[38;2;255;255;255;48;2;165;81;148m            \u001b[0m\u001b[38;2;255;255;255;48;2;140;109;49m            \u001b[0m\n",
       "\u001b[38;2;0;0;0;48;2;231;203;148m            \u001b[0m\u001b[38;2;255;255;255;48;2;107;110;207m            \u001b[0m\u001b[38;2;255;255;255;48;2;165;81;148m            \u001b[0m\u001b[38;2;255;255;255;48;2;140;109;49m            \u001b[0m\n",
       "\u001b[38;2;0;0;0;48;2;231;203;148m   \u001b[0m\u001b[38;2;0;0;0;48;2;231;203;148mTPU 6\u001b[0m\u001b[38;2;0;0;0;48;2;231;203;148m    \u001b[0m\u001b[38;2;255;255;255;48;2;107;110;207m   \u001b[0m\u001b[38;2;255;255;255;48;2;107;110;207mTPU 7\u001b[0m\u001b[38;2;255;255;255;48;2;107;110;207m    \u001b[0m\u001b[38;2;255;255;255;48;2;165;81;148m   \u001b[0m\u001b[38;2;255;255;255;48;2;165;81;148mTPU 4\u001b[0m\u001b[38;2;255;255;255;48;2;165;81;148m    \u001b[0m\u001b[38;2;255;255;255;48;2;140;109;49m   \u001b[0m\u001b[38;2;255;255;255;48;2;140;109;49mTPU 5\u001b[0m\u001b[38;2;255;255;255;48;2;140;109;49m    \u001b[0m\n",
       "\u001b[38;2;0;0;0;48;2;231;203;148m            \u001b[0m\u001b[38;2;255;255;255;48;2;107;110;207m            \u001b[0m\u001b[38;2;255;255;255;48;2;165;81;148m            \u001b[0m\u001b[38;2;255;255;255;48;2;140;109;49m            \u001b[0m\n",
       "\u001b[38;2;0;0;0;48;2;231;203;148m            \u001b[0m\u001b[38;2;255;255;255;48;2;107;110;207m            \u001b[0m\u001b[38;2;255;255;255;48;2;165;81;148m            \u001b[0m\u001b[38;2;255;255;255;48;2;140;109;49m            \u001b[0m\n",
       "\u001b[38;2;0;0;0;48;2;231;203;148m            \u001b[0m\u001b[38;2;255;255;255;48;2;107;110;207m            \u001b[0m\u001b[38;2;255;255;255;48;2;165;81;148m            \u001b[0m\u001b[38;2;255;255;255;48;2;140;109;49m            \u001b[0m\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "arr_sharded = jax.device_put(arr, sharding)\n",
    "\n",
    "print(arr_sharded)\n",
    "jax.debug.visualize_array_sharding(arr_sharded)"
   ]
  },
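  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A `PartitionSpec` entry of `None` leaves the corresponding array axis unsharded. As a minimal sketch reusing `arr` and `mesh` from above, `P('x', None)` splits the rows across the `'x'` mesh axis and replicates each row-block across the four devices of the `'y'` axis:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rows are split in two along 'x'; each half is replicated across 'y'.\n",
    "arr_replicated = jax.device_put(arr, jax.sharding.NamedSharding(mesh, P('x', None)))\n",
    "jax.debug.visualize_array_sharding(arr_replicated)"
   ]
  },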
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "UEObolTqw4pp"
   },
   "source": [
    "The device numbers here are not in numerical order, because the mesh reflects the underlying toroidal topology of the device.\n",
    "\n",
    "The {class}`~jax.sharding.NamedSharding` includes a parameter called `memory_kind`. This parameter determines the type of memory to be used and defaults to `device`. You can set this parameter to `pinned_host` if you prefer to place it on the host.\n",
    "\n",
    "To create a new sharding that only differs from an existing sharding in terms of its memory kind, you can use the `with_memory_kind` method on the existing sharding."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "aKNeOHTJnqmS",
    "outputId": "847c53ec-8b2e-4be0-f993-7fde7d77c0f2"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "pinned_host\n",
      "device\n"
     ]
    }
   ],
   "source": [
    "s_host = jax.NamedSharding(mesh, P('x', 'y'), memory_kind='pinned_host')\n",
    "s_dev = s_host.with_memory_kind('device')\n",
    "arr_host = jax.device_put(arr, s_host)\n",
    "arr_dev = jax.device_put(arr, s_dev)\n",
    "print(arr_host.sharding.memory_kind)\n",
    "print(arr_dev.sharding.memory_kind)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "jDHYnVqHwaST"
   },
   "source": [
    "## 1. Automatic parallelism via `jit`\n",
    "\n",
    "Once you have sharded data, the easiest way to do parallel computation is to simply pass the data to a {func}`jax.jit`-compiled function! In JAX, you need to only specify how you want the input and output of your code to be partitioned, and the compiler will figure out how to: 1) partition everything inside; and 2) compile inter-device communications.\n",
    "\n",
    "The XLA compiler behind `jit` includes heuristics for optimizing computations across multiple devices.\n",
    "In the simplest of cases, those heuristics boil down to *computation follows data*.\n",
    "\n",
    "To demonstrate how auto-parallelization works in JAX, below is an example that uses a {func}`jax.jit`-decorated staged-out function: it's a simple element-wise function, where the computation for each shard will be performed on the device associated with that shard, and the output is sharded in the same way:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "outputId": "de46f86a-6907-49c8-f36c-ed835e78bc3d"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "shardings match: True\n"
     ]
    }
   ],
   "source": [
    "@jax.jit\n",
    "def f_elementwise(x):\n",
    "  return 2 * jnp.sin(x) + 1\n",
    "\n",
    "result = f_elementwise(arr_sharded)\n",
    "\n",
    "print(\"shardings match:\", result.sharding == arr_sharded.sharding)"
   ]
  },
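  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can also visualize the result to confirm that the output kept the input's layout, with each shard computed on the device that already held its data:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The element-wise output is laid out exactly like the input.\n",
    "jax.debug.visualize_array_sharding(result)"
   ]
  },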
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As computations get more complex, the compiler makes decisions about how to best propagate the sharding of the data.\n",
    "\n",
    "Here, you sum along the leading axis of `x`, and visualize how the result values are stored across multiple devices (with {func}`jax.debug.visualize_array_sharding`):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "outputId": "90c3b997-3653-4a7b-c8ff-12a270f11d02"
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #ffffff; text-decoration-color: #ffffff; background-color: #393b79\"> TPU 0,6 </span><span style=\"color: #ffffff; text-decoration-color: #ffffff; background-color: #de9ed6\"> TPU 1,7 </span><span style=\"color: #ffffff; text-decoration-color: #ffffff; background-color: #ad494a\"> TPU 2,4 </span><span style=\"color: #000000; text-decoration-color: #000000; background-color: #b5cf6b\"> TPU 3,5 </span>\n",
       "<span style=\"color: #ffffff; text-decoration-color: #ffffff; background-color: #393b79\">         </span><span style=\"color: #ffffff; text-decoration-color: #ffffff; background-color: #de9ed6\">         </span><span style=\"color: #ffffff; text-decoration-color: #ffffff; background-color: #ad494a\">         </span><span style=\"color: #000000; text-decoration-color: #000000; background-color: #b5cf6b\">         </span>\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;57;59;121mTPU 0,6\u001b[0m\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214mTPU 1,7\u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\u001b[38;2;255;255;255;48;2;173;73;74mTPU 2,4\u001b[0m\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107mTPU 3,5\u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\n",
       "\u001b[38;2;255;255;255;48;2;57;59;121m         \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m         \u001b[0m\u001b[38;2;255;255;255;48;2;173;73;74m         \u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107m         \u001b[0m\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[48. 52. 56. 60. 64. 68. 72. 76.]\n"
     ]
    }
   ],
   "source": [
    "@jax.jit\n",
    "def f_contract(x):\n",
    "  return x.sum(axis=0)\n",
    "\n",
    "result = f_contract(arr_sharded)\n",
    "jax.debug.visualize_array_sharding(result)\n",
    "print(result)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Q4N5mrr9i_ki"
   },
   "source": [
    "The result is partially replicated: that is, the first two elements of the array are replicated on devices `0` and `6`, the second on `1` and `7`, and so on.\n",
    "\n",
    "### 1.1 Sharding transformation between memory types\n",
    "\n",
    "The output sharding of a {func}`jax.jit` function can differ from the input sharding if you specify the output sharding using the `out_shardings` parameter. Specifically, the `memory_kind` of the output can be different from that of the input array.\n",
    "\n",
    "#### Example 1: Pinned host to device memory\n",
    "\n",
    "In the example below, the {func}`jax.jit` function `f` takes an array sharded in `pinned_host` memory and generates an array in `device` memory."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "PXu3MhafyRHo",
    "outputId": "7bc6821f-a4a9-4cf8-8b21-e279d516d27b"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[ 0.  1.  2.  3.  4.  5.  6.  7.]\n",
      " [ 8.  9. 10. 11. 12. 13. 14. 15.]\n",
      " [16. 17. 18. 19. 20. 21. 22. 23.]\n",
      " [24. 25. 26. 27. 28. 29. 30. 31.]]\n",
      "device\n"
     ]
    }
   ],
   "source": [
    "f = jax.jit(lambda x: x, out_shardings=s_dev)\n",
    "out_dev = f(arr_host)\n",
    "print(out_dev)\n",
    "print(out_dev.sharding.memory_kind)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "LuYFqpcBySiX"
   },
   "source": [
    "#### Example 2: Device to pinned_host memory\n",
    "\n",
    "In the example below, the {func}`jax.jit` function `g` takes an array sharded in `device` memory and generates an array in `pinned_host` memory."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "qLsgNlKfybRw",
    "outputId": "a16448b9-7e39-408f-b200-505f65ad4464"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[ 0.  1.  2.  3.  4.  5.  6.  7.]\n",
      " [ 8.  9. 10. 11. 12. 13. 14. 15.]\n",
      " [16. 17. 18. 19. 20. 21. 22. 23.]\n",
      " [24. 25. 26. 27. 28. 29. 30. 31.]]\n",
      "pinned_host\n"
     ]
    }
   ],
   "source": [
    "g = jax.jit(lambda x: x, out_shardings=s_host)\n",
    "out_host = g(arr_dev)\n",
    "print(out_host)\n",
    "print(out_host.sharding.memory_kind)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "7BGD31-owaSU"
   },
   "source": [
    "## 2. Semi-automated sharding with constraints\n",
    "\n",
    "If you'd like to have some control over the sharding used within a particular computation, JAX offers the {func}`~jax.lax.with_sharding_constraint` function. You can use {func}`jax.lax.with_sharding_constraint` (in place of {func}`jax.device_put()`) together with {func}`jax.jit` for more control over how the compiler constraints how the intermediate values and outputs are distributed.\n",
    "\n",
    "For example, suppose that within `f_contract` above, you'd prefer the output not to be partially-replicated, but rather to be fully sharded across the eight devices:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "outputId": "8468f5c6-76ca-4367-c9f2-93c723687cfd"
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #ffffff; text-decoration-color: #ffffff; background-color: #393b79\">  TPU 0  </span><span style=\"color: #ffffff; text-decoration-color: #ffffff; background-color: #d6616b\">  TPU 1  </span><span style=\"color: #ffffff; text-decoration-color: #ffffff; background-color: #8ca252\">  TPU 2  </span><span style=\"color: #ffffff; text-decoration-color: #ffffff; background-color: #de9ed6\">  TPU 3  </span><span style=\"color: #000000; text-decoration-color: #000000; background-color: #e7cb94\">  TPU 6  </span><span style=\"color: #ffffff; text-decoration-color: #ffffff; background-color: #6b6ecf\">  TPU 7  </span><span style=\"color: #ffffff; text-decoration-color: #ffffff; background-color: #a55194\">  TPU 4  </span><span style=\"color: #ffffff; text-decoration-color: #ffffff; background-color: #8c6d31\">  TPU 5  </span>\n",
       "<span style=\"color: #ffffff; text-decoration-color: #ffffff; background-color: #393b79\">         </span><span style=\"color: #ffffff; text-decoration-color: #ffffff; background-color: #d6616b\">         </span><span style=\"color: #ffffff; text-decoration-color: #ffffff; background-color: #8ca252\">         </span><span style=\"color: #ffffff; text-decoration-color: #ffffff; background-color: #de9ed6\">         </span><span style=\"color: #000000; text-decoration-color: #000000; background-color: #e7cb94\">         </span><span style=\"color: #ffffff; text-decoration-color: #ffffff; background-color: #6b6ecf\">         </span><span style=\"color: #ffffff; text-decoration-color: #ffffff; background-color: #a55194\">         </span><span style=\"color: #ffffff; text-decoration-color: #ffffff; background-color: #8c6d31\">         </span>\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[38;2;255;255;255;48;2;57;59;121m  \u001b[0m\u001b[38;2;255;255;255;48;2;57;59;121mTPU 0\u001b[0m\u001b[38;2;255;255;255;48;2;57;59;121m  \u001b[0m\u001b[38;2;255;255;255;48;2;214;97;107m  \u001b[0m\u001b[38;2;255;255;255;48;2;214;97;107mTPU 1\u001b[0m\u001b[38;2;255;255;255;48;2;214;97;107m  \u001b[0m\u001b[38;2;255;255;255;48;2;140;162;82m  \u001b[0m\u001b[38;2;255;255;255;48;2;140;162;82mTPU 2\u001b[0m\u001b[38;2;255;255;255;48;2;140;162;82m  \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m  \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214mTPU 3\u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m  \u001b[0m\u001b[38;2;0;0;0;48;2;231;203;148m  \u001b[0m\u001b[38;2;0;0;0;48;2;231;203;148mTPU 6\u001b[0m\u001b[38;2;0;0;0;48;2;231;203;148m  \u001b[0m\u001b[38;2;255;255;255;48;2;107;110;207m  \u001b[0m\u001b[38;2;255;255;255;48;2;107;110;207mTPU 7\u001b[0m\u001b[38;2;255;255;255;48;2;107;110;207m  \u001b[0m\u001b[38;2;255;255;255;48;2;165;81;148m  \u001b[0m\u001b[38;2;255;255;255;48;2;165;81;148mTPU 4\u001b[0m\u001b[38;2;255;255;255;48;2;165;81;148m  \u001b[0m\u001b[38;2;255;255;255;48;2;140;109;49m  \u001b[0m\u001b[38;2;255;255;255;48;2;140;109;49mTPU 5\u001b[0m\u001b[38;2;255;255;255;48;2;140;109;49m  \u001b[0m\n",
       "\u001b[38;2;255;255;255;48;2;57;59;121m         \u001b[0m\u001b[38;2;255;255;255;48;2;214;97;107m         \u001b[0m\u001b[38;2;255;255;255;48;2;140;162;82m         \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m         \u001b[0m\u001b[38;2;0;0;0;48;2;231;203;148m         \u001b[0m\u001b[38;2;255;255;255;48;2;107;110;207m         \u001b[0m\u001b[38;2;255;255;255;48;2;165;81;148m         \u001b[0m\u001b[38;2;255;255;255;48;2;140;109;49m         \u001b[0m\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[48. 52. 56. 60. 64. 68. 72. 76.]\n"
     ]
    }
   ],
   "source": [
    "@jax.jit\n",
    "def f_contract_2(x):\n",
    "  out = x.sum(axis=0)\n",
    "  sharding = jax.sharding.NamedSharding(mesh, P('x'))\n",
    "  return jax.lax.with_sharding_constraint(out, sharding)\n",
    "\n",
    "result = f_contract_2(arr_sharded)\n",
    "jax.debug.visualize_array_sharding(result)\n",
    "print(result)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This gives you a function with the particular output sharding you'd like.\n",
    "\n",
    "## 3. Manual parallelism with `shard_map`\n",
    "\n",
    "In the automatic parallelism methods explored above, you can write a function as if you're operating on the full dataset, and `jit` will split that computation across multiple devices. By contrast, with {func}`jax.experimental.shard_map.shard_map` you write the function that will handle a single shard of data, and `shard_map` will construct the full function.\n",
    "\n",
    "`shard_map` works by mapping a function across a particular *mesh* of devices (`shard_map` maps over shards). In the example below:\n",
    "\n",
    "- As before, {class}`jax.sharding.Mesh` allows for precise device placement, with the axis names parameter for logical and physical axis names.\n",
    "- The `in_specs` argument determines the shard sizes. The `out_specs` argument identifies how the blocks are assembled back together.\n",
    "\n",
    "**Note:** {func}`jax.experimental.shard_map.shard_map` code can work inside {func}`jax.jit` if you need it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "outputId": "435c32f3-557a-4676-c11b-17e6bab8c1e2"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Array([ 1.        ,  2.682942  ,  2.818595  ,  1.28224   , -0.513605  ,\n",
       "       -0.9178486 ,  0.44116896,  2.3139732 ,  2.9787164 ,  1.824237  ,\n",
       "       -0.08804226, -0.99998045, -0.07314599,  1.8403342 ,  2.9812148 ,\n",
       "        2.3005757 ,  0.42419332, -0.92279506, -0.50197446,  1.2997544 ,\n",
       "        2.8258905 ,  2.6733112 ,  0.98229736, -0.69244075, -0.81115675,\n",
       "        0.7352965 ,  2.525117  ,  2.912752  ,  1.5418116 , -0.32726777,\n",
       "       -0.97606325,  0.19192469], dtype=float32)"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from jax.experimental.shard_map import shard_map\n",
    "mesh = jax.make_mesh((8,), ('x',))\n",
    "\n",
    "f_elementwise_sharded = shard_map(\n",
    "    f_elementwise,\n",
    "    mesh=mesh,\n",
    "    in_specs=P('x'),\n",
    "    out_specs=P('x'))\n",
    "\n",
    "arr = jnp.arange(32)\n",
    "f_elementwise_sharded(arr)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The function you write only \"sees\" a single batch of the data, which you can check by printing the device local shape:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "outputId": "99a3dc6e-154a-4ef6-8eaa-3dd0b68fb1da"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "global shape: x.shape=(32,)\n",
      "device local shape: x.shape=(4,)\n"
     ]
    }
   ],
   "source": [
    "x = jnp.arange(32)\n",
    "print(f\"global shape: {x.shape=}\")\n",
    "\n",
    "def f(x):\n",
    "  print(f\"device local shape: {x.shape=}\")\n",
    "  return x * 2\n",
    "\n",
    "y = shard_map(f, mesh=mesh, in_specs=P('x'), out_specs=P('x'))(x)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Because each of your functions only \"sees\" the device-local part of the data, it means that aggregation-like functions require some extra thought.\n",
    "\n",
    "For example, here's what a `shard_map` of a {func}`jax.numpy.sum` looks like:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "outputId": "1e9a45f5-5418-4246-c75b-f9bc6dcbbe72"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Array([  6,  22,  38,  54,  70,  86, 102, 118], dtype=int32)"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def f(x):\n",
    "  return jnp.sum(x, keepdims=True)\n",
    "\n",
    "shard_map(f, mesh=mesh, in_specs=P('x'), out_specs=P('x'))(x)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Your function `f` operates separately on each shard, and the resulting summation reflects this.\n",
    "\n",
    "If you want to sum across shards, you need to explicitly request it using collective operations like {func}`jax.lax.psum`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "outputId": "4fd29e80-4fee-42b7-ff80-29f9887ab38d"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Array(496, dtype=int32)"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def f(x):\n",
    "  sum_in_shard = x.sum()\n",
    "  return jax.lax.psum(sum_in_shard, 'x')\n",
    "\n",
    "shard_map(f, mesh=mesh, in_specs=P('x'), out_specs=P())(x)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Because the output no longer has a sharded dimension, set `out_specs=P()` (recall that the `out_specs` argument identifies how the blocks are assembled back together in `shard_map`).\n",
    "\n",
    "## Comparing the three approaches\n",
    "\n",
    "With these concepts fresh in our mind, let's compare the three approaches for a simple neural network layer.\n",
    "\n",
    "Start by defining your canonical function like this:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "id": "1TdhfTsoiqS1"
   },
   "outputs": [],
   "source": [
    "@jax.jit\n",
    "def layer(x, weights, bias):\n",
    "  return jax.nn.sigmoid(x @ weights + bias)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "outputId": "f3007fe4-f6f3-454e-e7c5-3638de484c0a"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Array([0.02138912, 0.893112  , 0.59892005, 0.97742504], dtype=float32)"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import numpy as np\n",
    "rng = np.random.default_rng(0)\n",
    "\n",
    "x = rng.normal(size=(32,))\n",
    "weights = rng.normal(size=(32, 4))\n",
    "bias = rng.normal(size=(4,))\n",
    "\n",
    "layer(x, weights, bias)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can automatically run this in a distributed manner using {func}`jax.jit` and passing appropriately sharded data.\n",
    "\n",
    "If you shard the leading axis of both `x` and `weights` in the same way, then the matrix multiplication will automatically happen in parallel:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "outputId": "80be899e-8dbc-4bfc-acd2-0f3d554a0aa5"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Array([0.02138912, 0.893112  , 0.59892005, 0.97742504], dtype=float32)"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "mesh = jax.make_mesh((8,), ('x',))\n",
    "sharding = jax.sharding.NamedSharding(mesh, P('x'))\n",
    "\n",
    "x_sharded = jax.device_put(x, sharding)\n",
    "weights_sharded = jax.device_put(weights, sharding)\n",
    "\n",
    "layer(x_sharded, weights_sharded, bias)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Alternatively, you can use {func}`jax.lax.with_sharding_constraint` in the function to automatically distribute unsharded inputs:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "outputId": "bb63e8da-ff4f-4e95-f083-10584882daf4"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Array([0.02138914, 0.89311206, 0.5989201 , 0.97742516], dtype=float32)"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "@jax.jit\n",
    "def layer_auto(x, weights, bias):\n",
    "  x = jax.lax.with_sharding_constraint(x, sharding)\n",
    "  weights = jax.lax.with_sharding_constraint(weights, sharding)\n",
    "  return layer(x, weights, bias)\n",
    "\n",
    "layer_auto(x, weights, bias)  # pass in unsharded inputs"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Finally, you can do the same thing with `shard_map`, using {func}`jax.lax.psum` to indicate the cross-shard collective required for the matrix product:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "outputId": "568d1c85-39a7-4dba-f09a-0e4f7c2ea918"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Array([0.02138914, 0.89311206, 0.5989201 , 0.97742516], dtype=float32)"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from functools import partial\n",
    "\n",
    "@jax.jit\n",
    "@partial(shard_map, mesh=mesh,\n",
    "         in_specs=(P('x'), P('x', None), P(None)),\n",
    "         out_specs=P(None))\n",
    "def layer_sharded(x, weights, bias):\n",
    "  return jax.nn.sigmoid(jax.lax.psum(x @ weights, 'x') + bias)\n",
    "\n",
    "layer_sharded(x, weights, bias)"
   ]
  },
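  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick sanity check (expect only small floating-point differences between approaches), you can confirm that the three versions agree:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The three approaches differ only in how the computation is distributed,\n",
    "# so their outputs should agree up to floating-point tolerance.\n",
    "print(jnp.allclose(layer(x, weights, bias), layer_auto(x, weights, bias)))\n",
    "print(jnp.allclose(layer(x, weights, bias), layer_sharded(x, weights, bias)))"
   ]
  },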
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Next steps\n",
    "\n",
    "This tutorial serves as a brief introduction of sharded and parallel computation in JAX.\n",
    "\n",
    "To learn about each SPMD method in-depth, check out these docs:\n",
    "- {doc}`../notebooks/Distributed_arrays_and_automatic_parallelization`\n",
    "- {doc}`../notebooks/shard_map`"
   ]
  }
 ],
 "metadata": {
  "accelerator": "TPU",
  "colab": {
   "gpuType": "V28",
   "provenance": [],
   "toc_visible": true
  },
  "jupytext": {
   "formats": "ipynb,md:myst"
  },
  "kernelspec": {
   "display_name": "Python 3",
   "name": "python3"
  },
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
